import json
from time import time
from typing import Optional, Any, Dict, List, Union, Sequence
from pydantic import BaseModel, ConfigDict, Field
from agentica.media import AudioResponse
from agentica.utils.log import logger


class MessageReferences(BaseModel):
    """RAG references that were attached to a user message, plus retrieval metadata."""

    # Query text that was issued to retrieve the references.
    query: str
    # Retrieved reference records (from the vector database or function calls); None if absent.
    references: Optional[List[Dict[str, Any]]] = Field(default=None)
    # Time spent retrieving the references, if recorded.
    # NOTE(review): unit is not fixed here — presumably seconds; confirm with the producer.
    time: Optional[float] = Field(default=None)


class Message(BaseModel):
    """A single chat message exchanged with a Model.

    Only ``role``, ``content``, ``audio``, ``name``, ``tool_call_id`` and
    ``tool_calls`` are serialized by :meth:`to_dict`. The remaining fields hold
    local bookkeeping (tool metadata, metrics, RAG references, timestamps).
    """

    # The role of the message author.
    # One of system, user, assistant, or tool.
    role: str
    # The contents of the message. Content is required for all messages,
    # but may be None for assistant messages that only carry function/tool calls.
    content: Optional[Union[List[Any], str]] = None
    # An optional name for the participant.
    # Lets the model differentiate between participants that share the same role.
    name: Optional[str] = None
    # ID of the tool call that this message is responding to.
    tool_call_id: Optional[str] = None
    # The tool calls generated by the model, such as function calls.
    tool_calls: Optional[List[Dict[str, Any]]] = None

    # Additional input modalities
    audio: Optional[Any] = None
    images: Optional[Sequence[Any]] = None
    videos: Optional[Sequence[Any]] = None

    # Audio output produced by the model
    audio_output: Optional[AudioResponse] = None

    # The thinking content from the model
    thinking: Optional[str] = None
    redacted_thinking: Optional[str] = None

    # Provider-specific data that may be needed on subsequent messages
    provider_data: Optional[Dict[str, Any]] = None

    # --- Data not sent to the Model API ---
    # The reasoning content from the model
    reasoning_content: Optional[str] = None
    # Name of the tool that was called (can also be populated via "tool_call_name").
    tool_name: Optional[str] = Field(None, alias="tool_call_name")
    # Arguments passed to the tool (can also be populated via "tool_call_arguments").
    tool_args: Optional[Any] = Field(None, alias="tool_call_arguments")
    # Whether the tool call resulted in an error
    tool_call_error: Optional[bool] = None
    # If True, the agent will stop executing after this tool call.
    stop_after_tool_call: bool = False

    # Metrics for the message. Not sent to the Model API.
    metrics: Dict[str, Any] = Field(default_factory=dict)

    # The references added to the message for RAG
    references: Optional[MessageReferences] = None

    # Unix timestamp (whole seconds, from time.time()) at which the message was created.
    created_at: int = Field(default_factory=lambda: int(time()))

    # extra="allow": unknown keys are kept; populate_by_name: fields with an alias
    # accept either the alias or the field name.
    model_config = ConfigDict(extra="allow", populate_by_name=True, arbitrary_types_allowed=True)

    def get_content_string(self) -> str:
        """Return the message content as a string.

        - ``str`` content is returned unchanged.
        - ``list`` content whose first item is a dict with a ``"text"`` key yields
          that first item's text (later parts are ignored); a ``None`` text yields ``""``.
        - Any other ``list`` content (including ``[]``) is serialized with ``json.dumps``.
        - ``None`` content yields ``""``.
        """
        if isinstance(self.content, str):
            return self.content
        if isinstance(self.content, list):
            first = self.content[0] if self.content else None
            if isinstance(first, dict) and "text" in first:
                text = first["text"]
                # Keep the declared ``str`` return type even for {"text": None} or non-str text.
                return "" if text is None else str(text)
            return json.dumps(self.content, ensure_ascii=False)
        return ""

    def to_dict(self) -> Dict[str, Any]:
        """Return the subset of fields that is sent to the Model API.

        Includes ``role``, ``content``, ``audio``, ``name``, ``tool_call_id`` and
        ``tool_calls``; ``None`` values are dropped except ``content``, which is
        always present (possibly as ``None``).
        """
        _dict = self.model_dump(
            exclude_none=True,
            include={"role", "content", "audio", "name", "tool_call_id", "tool_calls"},
        )
        # Manually add the content field even if it is None
        if self.content is None:
            _dict["content"] = None

        return _dict

    def log(self, level: Optional[str] = None):
        """Log the message to the console.

        @param level: The level to log the message at. One of debug, info, warning, or error
            (case-insensitive). Defaults to debug; unrecognized values also fall back to debug.
        """
        level = (level or "debug").lower()
        if level not in ("debug", "info", "warning", "error"):
            level = "debug"
        _logger = getattr(logger, level)

        _logger(f"============== {self.role} ==============")
        if self.name:
            _logger(f"Name: {self.name}")
        if self.tool_call_id:
            _logger(f"Tool call Id: {self.tool_call_id}")
        if self.content:
            if isinstance(self.content, (str, list)):
                _logger(self.content)
            elif isinstance(self.content, dict):
                # content is typed str | list, so this presumably only triggers when
                # validation was bypassed (e.g. attribute reassignment).
                # default=str keeps logging from raising on non-JSON-serializable values.
                _logger(json.dumps(self.content, ensure_ascii=False, default=str))
        if self.tool_calls:
            # default=str keeps logging from raising on non-JSON-serializable values.
            _logger(f"Tool Calls: {json.dumps(self.tool_calls, ensure_ascii=False, default=str)}")
        if self.images:
            _logger(f"Images added: {len(self.images)}")
        if self.videos:
            _logger(f"Videos added: {len(self.videos)}")
        if self.audio:
            if isinstance(self.audio, dict):
                # A dict describes one audio payload ("id" or "data" key); len() would
                # count its keys, not audio files.
                _logger("Audio file added")
                if "id" in self.audio:
                    _logger(f"Audio ID: {self.audio['id']}")
                elif "data" in self.audio:
                    _logger("Message contains raw audio data")
            else:
                _logger(f"Audio file added: {self.audio}")

    def content_is_valid(self) -> bool:
        """Return True if content is set and non-empty (non-empty string or list)."""

        return self.content is not None and len(self.content) > 0


class SystemMessage(Message):
    """A Message whose ``role`` defaults to ``"system"``."""

    role: str = Field(default="system")


class UserMessage(Message):
    """A Message whose ``role`` defaults to ``"user"``."""

    role: str = Field(default="user")


class AssistantMessage(Message):
    """A Message whose ``role`` defaults to ``"assistant"`` (output from the model)."""

    role: str = Field(default="assistant")


class ToolMessage(Message):
    """A Message whose ``role`` defaults to ``"tool"`` (a tool/function result)."""

    role: str = Field(default="tool")
